{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.TransformerModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TransformerModel\n",
    "\n",
    "> This is an unofficial PyTorch implementation by Ignacio Oguiza - oguiza@gmail.com based on:\n",
    "* Rußwurm, M., Pelletier, C., Zollner, M., Lefèvre, S., & Körner, M. (2019). BreizhCrops: A Time Series Dataset for Crop Type Mapping. arXiv preprint arXiv:1905.11893.\n",
    "* Official implementation: https://github.com/dl4sits/BreizhCrops/blob/master/breizhcrops/models/TransformerModel.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *\n",
    "from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TransformerModel(Module):\n",
    "    \"Linear embedding -> Transformer encoder -> max over time steps -> linear head, for time series classification.\"\n",
    "    def __init__(self, c_in, c_out, d_model=64, n_head=1, d_ffn=128, dropout=0.1, activation=\"relu\", n_layers=1):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            c_in: the number of features (aka variables, dimensions, channels) in the time series dataset\n",
    "            c_out: the number of target classes\n",
    "            d_model: total dimension of the model.\n",
    "            n_head: the number of parallel attention heads.\n",
    "            d_ffn: the dimension of the feedforward network model.\n",
    "            dropout: the dropout value passed to every TransformerEncoderLayer.\n",
    "            activation: the activation function of intermediate layer, relu or gelu.\n",
    "            n_layers: the number of sub-encoder-layers in the encoder.\n",
    "\n",
    "        Input shape:\n",
    "            bs (batch size) x nvars (aka variables, dimensions, channels) x seq_len (aka time steps)\n",
    "\n",
    "        Output shape:\n",
    "            bs (batch size) x c_out\n",
    "        \"\"\"\n",
    "        # NOTE: layers are created in the same order as before so parameter names (state_dict keys)\n",
    "        # and seeded initialisation stay unchanged.\n",
    "        self.permute = Permute(2, 0, 1)\n",
    "        self.inlinear = nn.Linear(c_in, d_model)\n",
    "        self.relu = nn.ReLU()  # stateless, so the same instance is reused twice in forward\n",
    "        encoder_layer = TransformerEncoderLayer(d_model, n_head, dim_feedforward=d_ffn, dropout=dropout, activation=activation)\n",
    "        encoder_norm = nn.LayerNorm(d_model)\n",
    "        self.transformer_encoder = TransformerEncoder(encoder_layer, n_layers, norm=encoder_norm)\n",
    "        self.transpose = Transpose(1, 0)\n",
    "        self.max = Max(1)\n",
    "        self.outlinear = nn.Linear(d_model, c_out)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.permute(x) # bs x nvars x seq_len -> seq_len x bs x nvars\n",
    "        x = self.inlinear(x) # seq_len x bs x nvars -> seq_len x bs x d_model\n",
    "        x = self.relu(x)\n",
    "        x = self.transformer_encoder(x) # seq_len x bs x d_model (shape unchanged)\n",
    "        x = self.transpose(x) # seq_len x bs x d_model -> bs x seq_len x d_model\n",
    "        x = self.max(x) # max over dim 1 (time steps): bs x seq_len x d_model -> bs x d_model\n",
    "        x = self.relu(x)\n",
    "        x = self.outlinear(x) # bs x d_model -> bs x c_out\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100930\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TransformerModel(\n",
       "  (permute): Permute((2, 0, 1))\n",
       "  (inlinear): Linear(in_features=3, out_features=64, bias=True)\n",
       "  (relu): ReLU()\n",
       "  (transformer_encoder): TransformerEncoder(\n",
       "    (layers): ModuleList(\n",
       "      (0): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): _LinearWithBias(in_features=64, out_features=64, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=64, out_features=128, bias=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "        (linear2): Linear(in_features=128, out_features=64, bias=True)\n",
       "        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.1, inplace=False)\n",
       "        (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (1): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): _LinearWithBias(in_features=64, out_features=64, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=64, out_features=128, bias=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "        (linear2): Linear(in_features=128, out_features=64, bias=True)\n",
       "        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.1, inplace=False)\n",
       "        (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (2): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): _LinearWithBias(in_features=64, out_features=64, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=64, out_features=128, bias=True)\n",
       "        (dropout): Dropout(p=0.1, inplace=False)\n",
       "        (linear2): Linear(in_features=128, out_features=64, bias=True)\n",
       "        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.1, inplace=False)\n",
       "        (dropout2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "    )\n",
       "    (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (transpose): Transpose((1, 0))\n",
       "  (max): Max(1, keepdim=False)\n",
       "  (outlinear): Linear(in_features=64, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: a random batch must come out as one row of c_out values per sample\n",
    "bs, nvars, seq_len, c_out = 16, 3, 96, 2\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "\n",
    "model = TransformerModel(\n",
    "    nvars, c_out,\n",
    "    d_model=64, n_head=1, d_ffn=128,\n",
    "    dropout=0.1, activation='gelu', n_layers=3,\n",
    ")\n",
    "test_eq(model(xb).shape, [bs, c_out])\n",
    "print(count_parameters(model))\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "# NOTE(review): create_scripts and beep are not defined in this notebook (presumably they come from the\n",
    "# tsai star-imports above); the value returned by create_scripts is handed straight to beep -- confirm what each does.\n",
    "out = create_scripts()\n",
    "beep(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
